{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "spark=SparkSession.builder.appName('random_forest').getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "df=spark.read.csv('cars.csv',inferSchema=True,header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5207, 6)\n"
     ]
    }
   ],
   "source": [
    "print((df.count(),len(df.columns)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+-----------------+-----------------+------------------+------------------+\n",
      "|summary|              地区|       驾驶员年龄|             驾龄|      每年保养次数|          汽车类型|\n",
      "+-------+------------------+-----------------+-----------------+------------------+------------------+\n",
      "|  count|              5207|             5207|             5207|              5207|              5207|\n",
      "|   mean| 4.156135970808527|28.91357787593624|8.799404647589784|1.3578836182062608|2.4536201267524484|\n",
      "| stddev|0.9395952913420921|6.860914622077383|7.303439479306463|1.4217206625193524|0.8772697645768409|\n",
      "|    min|                 1|             17.5|              0.5|               0.0|                 1|\n",
      "|    max|                 5|             42.0|             23.0|               5.5|                 4|\n",
      "+-------+------------------+-----------------+-----------------+------------------+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.describe().select('summary','地区','驾驶员年龄','驾龄','每年保养次数','汽车类型').show() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- 地区: integer (nullable = true)\n",
      " |-- 驾驶员年龄: double (nullable = true)\n",
      " |-- 驾龄: double (nullable = true)\n",
      " |-- 每年保养次数: double (nullable = true)\n",
      " |-- 汽车类型: integer (nullable = true)\n",
      " |-- 故障: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+----+-----+\n",
      "|汽车类型|故障|count|\n",
      "+--------+----+-----+\n",
      "|       1|   0|  558|\n",
      "|       1|   1|  238|\n",
      "|       2|   0| 1334|\n",
      "|       2|   1|  481|\n",
      "|       3|   0| 1577|\n",
      "|       3|   1|  457|\n",
      "|       4|   0|  484|\n",
      "|       4|   1|   78|\n",
      "+--------+----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('汽车类型','故障').count().orderBy('汽车类型','故障','count',ascending=True).show()  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+----+-----+\n",
      "|每年保养次数|故障|count|\n",
      "+------------+----+-----+\n",
      "|         0.0|   0| 1752|\n",
      "|         0.0|   1|  285|\n",
      "|         1.0|   0|  680|\n",
      "|         1.0|   1|  271|\n",
      "|         2.0|   0|  810|\n",
      "|         2.0|   1|  374|\n",
      "|         3.0|   0|  422|\n",
      "|         3.0|   1|  205|\n",
      "|         4.0|   0|  172|\n",
      "|         4.0|   1|   77|\n",
      "|         5.5|   0|  117|\n",
      "|         5.5|   1|   42|\n",
      "+------------+----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('每年保养次数','故障').count()\\\n",
    ".orderBy('每年保养次数','故障','count',ascending=True).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler  #VerctorAssembler 将多个列合并成向量列的特征转换器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_vec = VectorAssembler(inputCols=['地区', '驾驶员年龄', '驾龄', '每年保养次数', '汽车类型'], outputCol=\"features\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df_vec.transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- 地区: integer (nullable = true)\n",
      " |-- 驾驶员年龄: double (nullable = true)\n",
      " |-- 驾龄: double (nullable = true)\n",
      " |-- 每年保养次数: double (nullable = true)\n",
      " |-- 汽车类型: integer (nullable = true)\n",
      " |-- 故障: integer (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_df=df.select(['features','故障'])  #选择训练需要的模型列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df,test_df=model_df.randomSplit([0.7,0.3]) # 训练数据和测试数据分为7比3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+----+\n",
      "|            features|故障|\n",
      "+--------------------+----+\n",
      "|[1.0,22.0,2.5,0.0...|   1|\n",
      "|[1.0,22.0,2.5,1.0...|   1|\n",
      "|[1.0,22.0,2.5,1.0...|   1|\n",
      "|[1.0,22.0,2.5,1.0...|   0|\n",
      "|[1.0,22.0,2.5,1.0...|   0|\n",
      "|[1.0,22.0,2.5,1.0...|   1|\n",
      "|[1.0,22.0,2.5,1.0...|   1|\n",
      "|[1.0,27.0,2.5,0.0...|   0|\n",
      "|[1.0,27.0,2.5,0.0...|   1|\n",
      "|[1.0,27.0,2.5,0.0...|   1|\n",
      "+--------------------+----+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train_df.show(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_classifier=RandomForestClassifier(featuresCol='features',labelCol='故障',numTrees=30)\n",
    "rf_model = rf_classifier.fit(train_df) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_pred=rf_model.transform(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "评估每个属性的重要性:(5,[0,1,2,3,4],[0.5850305699488729,0.033799370247371376,0.23301277206058152,0.0795059584230612,0.06865132932011303])\n"
     ]
    }
   ],
   "source": [
    "print('{}{}'.format('评估每个属性的重要性:',rf_model.featureImportances)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------+----+---------------------------------------+----------------------------------------+----------+\n",
      "|features               |故障|rawPrediction                          |probability                             |prediction|\n",
      "+-----------------------+----+---------------------------------------+----------------------------------------+----------+\n",
      "|[1.0,17.5,0.5,0.0,2.0] |0   |[18.783571907130153,11.216428092869846]|[0.6261190635710051,0.37388093642899484]|0.0       |\n",
      "|[1.0,22.0,2.5,0.0,1.0] |1   |[13.564496325738299,16.435503674261703]|[0.45214987752460994,0.5478501224753901]|1.0       |\n",
      "|[1.0,27.0,6.0,1.0,2.0] |0   |[12.825942561920609,17.17405743807939] |[0.42753141873068695,0.572468581269313] |1.0       |\n",
      "|[1.0,27.0,6.0,1.0,3.0] |1   |[12.272550916929765,17.727449083070226]|[0.4090850305643256,0.5909149694356743] |1.0       |\n",
      "|[1.0,27.0,6.0,2.0,2.0] |1   |[13.251959727315288,16.748040272684708]|[0.4417319909105097,0.5582680090894904] |1.0       |\n",
      "|[1.0,27.0,9.0,1.0,3.0] |1   |[11.356614627371096,18.643385372628902]|[0.3785538209123699,0.6214461790876301] |1.0       |\n",
      "|[1.0,27.0,9.0,4.0,1.0] |0   |[12.378159288441315,17.621840711558683]|[0.41260530961471054,0.5873946903852895]|1.0       |\n",
      "|[1.0,32.0,13.0,0.0,3.0]|1   |[12.793652237134587,17.20634776286541] |[0.42645507457115295,0.573544925428847] |1.0       |\n",
      "|[1.0,32.0,13.0,3.0,3.0]|1   |[14.32007036806808,15.67992963193192]  |[0.4773356789356027,0.5226643210643973] |1.0       |\n",
      "|[1.0,32.0,16.5,2.0,2.0]|1   |[10.547587713858405,19.452412286141588]|[0.3515862571286136,0.6484137428713864] |1.0       |\n",
      "+-----------------------+----+---------------------------------------+----------------------------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "rf_pred.show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator #对二进制分类的评估器,它期望两个输入列:原始预测值和标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_bin=BinaryClassificationEvaluator(labelCol='故障').evaluate(rf_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "使用二元分类评估其器的评估结果:0.734692458171\n"
     ]
    }
   ],
   "source": [
    "print('{}{}'.format('使用二元分类评估其器的评估结果:',rf_bin))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseVector(5, {0: 0.585, 1: 0.0338, 2: 0.233, 3: 0.0795, 4: 0.0687})"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf_model.featureImportances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "import_feature=df.schema[\"features\"].metadata[\"ml_attr\"][\"attrs\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"numeric\": [{\"name\": \"地区\", \"idx\": 0}, {\"name\": \"驾驶员年龄\", \"idx\": 1}, {\"name\": \"驾龄\", \"idx\": 2}, {\"name\": \"每年保养次数\", \"idx\": 3}, {\"name\": \"汽车类型\", \"idx\": 4}]}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "print json.dumps(import_feature,encoding='utf-8',ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{u'numeric': [{u'idx': 0, u'name': u'\\u5730\\u533a'},\n",
       "  {u'idx': 1, u'name': u'\\u9a7e\\u9a76\\u5458\\u5e74\\u9f84'},\n",
       "  {u'idx': 2, u'name': u'\\u9a7e\\u9f84'},\n",
       "  {u'idx': 3, u'name': u'\\u6bcf\\u5e74\\u4fdd\\u517b\\u6b21\\u6570'},\n",
       "  {u'idx': 4, u'name': u'\\u6c7d\\u8f66\\u7c7b\\u578b'}]}"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
